library(here)
library(tidyverse)
library(scales)
library(redist)
library(sf)
library(patchwork)
library(mgcv)

# Hex colours used for the two parties (Democratic, Republican).
DEM <- "#0064B0"
GOP <- "#A0442C"
source(here("R/00_custom_functions.R"))

# Load data and merge -----------------------------------------------------

vtd_das19 <- read_rds(here("data/PA/pa_vtd_das_19.rds"))
vtd_das12 <- read_rds(here("data/PA/pa_vtd_das_12.rds")) %>%
    rename_with(~ str_c(., "_das12"), -precinct)
vtd_das4 <- read_rds(here("data/PA/pa_vtd_das_04.rds")) %>%
    rename_with(~ str_c(., "_das4"), -precinct)
vtd_orig <- read_rds(here("data/PA/pa_vtd_orig.rds"))
pa_shp <- read_rds(here("data/PA/pa_shp.rds")) %>%
    sf::st_as_sf()

# Join precinct shapes with every population/vote source, keyed by `precinct`.
# Areas are computed from the joined object itself (`.`), not from `pa_shp`:
# inner joins drop unmatched precincts, and then `pa_shp` areas would no
# longer line up row-for-row with `vtds`.
vtds <- inner_join(pa_shp, vtd_orig, by = "precinct") %>%
    inner_join(vtd_das19, by = "precinct", suffix = c("_orig", "_das19")) %>%
    inner_join(vtd_das12, by = "precinct") %>%
    inner_join(vtd_das4, by = "precinct") %>%
    mutate(hectares = as.numeric(sf::st_area(.)) / 1e4)


# Simulate ----------------------------------------------------------------
# Adjacency graph shared by every map; st_make_valid() repairs invalid
# geometries before adjacency is computed.
adj <- redist.adjacency(st_make_valid(vtds))

# One redist_map per population source. All four use the enacted plan `cd`,
# a 0.1% population tolerance and the same adjacency graph; only the
# population column differs.
pa_orig <- redist_map(vtds, total_pop = pop_orig,
                      existing_plan = cd, adj = adj, pop_tol = 0.001)
pa_das19 <- redist_map(vtds, total_pop = pop_das19,
                       existing_plan = cd, adj = adj, pop_tol = 0.001)
pa_das12 <- redist_map(vtds, total_pop = pop_das12,
                       existing_plan = cd, adj = adj, pop_tol = 0.001)
pa_das4 <- redist_map(vtds, total_pop = pop_das4,
                      existing_plan = cd, adj = adj, pop_tol = 0.001)

#' Load a cached redist_smc() sample, or draw one and cache it.
#'
#' @param map A redist_map to sample from.
#' @param path File path of the `.rds` cache.
#' @param nsims Number of plans to draw when no cache exists.
#' @return The redist_plans object read from, or just written to, `path`.
load_or_simulate <- function(map, path, nsims = 10000) {
    if (file.exists(path)) {
        return(read_rds(path))
    }
    # County constraint: `counties = county` is evaluated by redist_smc()
    # against the columns of `map`.
    plans <- redist_smc(map, nsims, counties = county)
    write_rds(plans, path, compress = "xz")
    plans
}

plans_orig <- load_or_simulate(pa_orig,
                               here("data/PA/sim_orig_cty_001_10k.rds"))
plans_das19 <- load_or_simulate(pa_das19,
                                here("data/PA/sim_das19_cty_001_10k.rds"))
plans_das12 <- load_or_simulate(pa_das12,
                                here("data/PA/sim_das12_cty_001_10k.rds"))
plans_das4 <- load_or_simulate(pa_das4,
                               here("data/PA/sim_das04_cty_001_10k.rds"))

# Each ensemble's own `total_pop` tally is relabelled with its data source.
plans_orig <- rename(plans_orig, pop_orig = total_pop)
plans_das19 <- rename(plans_das19, pop_das19 = total_pop)
plans_das12 <- rename(plans_das12, pop_das12 = total_pop)
plans_das4 <- rename(plans_das4, pop_das4 = total_pop)

#' Add district population tallies under other population sources.
#'
#' For every name in `pop_cols`, tallies `vtds[[name]]` by district for each
#' plan in `plans` and stores the result as an integer column of that name,
#' so every ensemble can be evaluated under every population source.
#'
#' @param plans A redist_plans object.
#' @param pop_cols Character vector of population column names in `vtds`.
#' @return `plans` with one added column per entry of `pop_cols`.
add_pop_tallies <- function(plans, pop_cols) {
    # Neither the plan matrix nor the district count depends on the column,
    # so compute them once. The district count comes from `pa_orig`; every
    # map is built from the same `vtds` and enacted plan `cd`.
    plan_mat <- get_plans_matrix(plans)
    n_dists <- attr(pa_orig, "ndists")
    for (col in pop_cols) {
        # NOTE: pop_tally is an unexported redist internal and may change
        # between redist versions.
        tally <- redist:::pop_tally(plan_mat, vtds[[col]], n_dists)
        plans[[col]] <- as.integer(tally)
    }
    plans
}

plans_orig <- add_pop_tallies(plans_orig,
                              c("pop_das19", "pop_das12", "pop_das4"))
plans_das19 <- add_pop_tallies(plans_das19,
                               c("pop_das12", "pop_das4", "pop_orig"))
plans_das12 <- add_pop_tallies(plans_das12,
                               c("pop_das19", "pop_das4", "pop_orig"))
plans_das4 <- add_pop_tallies(plans_das4,
                              c("pop_das19", "pop_das12", "pop_orig"))

# Set to TRUE to run (slow) sampling diagnostics.
run_check <- FALSE
if (run_check) {
    # Compare importance-weight distributions of two ensembles.
    qqplot(get_plans_weights(plans_orig), get_plans_weights(plans_das12))
    # Sample diversity: pairwise plan distances among draws with index <= 50.
    pm_orig <- plan_distances(filter(plans_orig, as.integer(draw) <= 50),
                              ncores = 8)
    pm_das12 <- plan_distances(filter(plans_das12, as.integer(draw) <= 50),
                               ncores = 8)
    # Explicit print(): inside `if`, only the last value would autoprint.
    print(summary(as.numeric(pm_orig)))
    print(summary(as.numeric(pm_das12)))
}

# Analyze -----------------------------------------------------------------

# Effective sample size of each ensemble, named by population source; used to
# scale standard errors in the analyses below.
n_eff <- unlist(lapply(
    list(orig = plans_orig, das19 = plans_das19,
         das12 = plans_das12, das4 = plans_das4),
    attr, which = "n_eff"
))


## Population deviation -----

#' Maximum population deviation of each sampled plan under every source.
#'
#' @param pl A redist_plans object with district columns pop_orig, pop_das19,
#'   pop_das12 and pop_das4.
#' @return Long table with columns `draw`, `source` (orig/das19/das12/das4)
#'   and `dev`, the largest |pop / target - 1| over the plan's districts.
calc_pop <- function(pl) {
    target <- get_target(pa_orig)
    rel_dev <- function(p) abs(p / target - 1)

    sampled <- subset_sampled(pl)
    per_draw <- summarize(group_by(sampled, draw),
                          dev_orig = max(rel_dev(pop_orig)),
                          dev_das19 = max(rel_dev(pop_das19)),
                          dev_das12 = max(rel_dev(pop_das12)),
                          dev_das4 = max(rel_dev(pop_das4)))
    pivot_longer(per_draw, starts_with("dev_"),
                 names_to = "source", names_prefix = "dev_",
                 values_to = "dev")
}

# Stack per-draw deviations of every ensemble; `sampled_from` records the
# ensemble each row came from.
pop_res <- bind_rows(
    orig = calc_pop(plans_orig),
    das19 = calc_pop(plans_das19),
    das12 = calc_pop(plans_das12),
    das4 = calc_pop(plans_das4),
    .id = "sampled_from"
)

# Fraction of plans whose maximum deviation is at least 0.1%, for every
# (sampled-from, evaluated-under) pair.
write_rds(
    summarize(group_by(pop_res, from = sampled_from, source),
              p = mean(dev >= 0.001)),
    here("data/PA/pa_exc.rds"),
    compress = "xz"
)

# Parity (maximum population deviation) of the enacted plan, drawn as dashed
# reference lines in the parity figure. Every row is placed in the
# "Census 2010" facet; `source` keeps the original label, with "Census"
# relabelled "Census 2010".
popdevs <- readRDS(here("data/PA/pa_parity.rds")) %>%
    rename(`Sampled from` = names)
popdevs$source <- replace(popdevs$`Sampled from`,
                          popdevs$`Sampled from` == "Census",
                          "Census 2010")
popdevs$`Sampled from` <- "Census 2010"


# Population sources shown in the parity figure (DAS 19 is omitted).
show_src <- c("orig", "das12", "das4")

pop_res %>%
    filter(source %in% show_src, sampled_from %in% show_src) %>%
    mutate(sampled_from = lbl_source(sampled_from),
           source = lbl_source(source)) %>%
    rename(`Sampled from` = sampled_from) %>%
    drop_na() %>%
ggplot(aes(dev, fill = source)) +
    facet_wrap(~ `Sampled from`, labeller = label_both) +
    # after_stat() replaces the deprecated ..count.. notation.
    geom_histogram(aes(y = after_stat(count / sum(count))), position = "dodge",
                   binwidth = 5e-4, boundary = 5e-4, alpha = 1) +
    geom_vline(data = popdevs, aes(xintercept = parity, color = source),
               lty = "dashed", lwd = .3) +
    scale_x_continuous("Maximum population deviation",
                       labels = percent_format(accuracy = 0.1),
                       breaks = seq(0, 0.006, by = 0.002),
                       expand = expansion(add = c(0, 0.0005))) +
    scale_y_continuous("Fraction of plans", labels = percent_format(accuracy = 1),
                       expand = expansion(mult = c(0, 0.05))) +
    # Label each enacted-plan reference line just to its right.
    geom_text(data = popdevs,
              aes(x = parity + 0.0012, y = 0.105,
                  label = paste0("Enacted:\n  ", source),
                  color = source),
              size = 2, family = "Times") +
    coord_cartesian(xlim = c(0, NA), ylim = c(0, NA)) +
    theme_ppmf() +
    labs(fill = "Evaluation\ndata source") +
    scale_fill_manual(values = PAL_DAS) +
    scale_color_manual(values = PAL_DAS) +
    guides(color = "none")
ggsave(here("figs/pa_sim_parity.pdf"), width = 6.5, height = 2)

## Democratic seats -----

#' Count Democratic-majority districts in each plan.
#'
#' District Democratic share is group_frac(pa_orig, ndv, ndv + nrv).
#'
#' @param pl A redist_plans object.
#' @return One row per `draw` with `n_dem`, the number of districts whose
#'   Democratic share exceeds 0.5.
calc_dem_seats <- function(pl) {
    shares <- mutate(pl, dem = group_frac(pa_orig, ndv, ndv + nrv))
    summarize(group_by(shares, draw), n_dem = sum(dem > 0.5))
}
# Democratic seat counts for every plan in each ensemble.
seats_res <- bind_rows(
    orig = calc_dem_seats(plans_orig),
    das12 = calc_dem_seats(plans_das12),
    das4 = calc_dem_seats(plans_das4),
    .id = "sampled_from") %>%
    as_tibble()
# NOTE(review): written to data/ rather than data/PA/ like the other outputs;
# confirm this is intended before moving it.
write_rds(seats_res, here("data/pa_n_dem.rds"))

# Histogram of seat counts over sampled plans (reference plans excluded),
# with the "resp" reference plan marked by a dashed line.
seats_res %>%
    filter(!(draw %in% c("gov", "resp", "court", "cd"))) %>%
    mutate(sampled_from = lbl_source(sampled_from)) %>%
    drop_na() %>%
ggplot(aes(n_dem, fill = sampled_from)) +
    # after_stat() replaces the deprecated ..count.. notation.
    geom_histogram(aes(y = after_stat(count / sum(count))), position = "dodge",
                   binwidth = 0.5, alpha = 1) +
    geom_vline(aes(xintercept = n_dem), lty = "dashed",
               data = filter(seats_res, draw %in% c("resp"))) +
    scale_y_continuous("Fraction of plans", labels = percent_format(accuracy = 1),
                       expand = expansion(mult = c(0, 0.05))) +
    theme_ppmf() +
    labs(x = "Democratic seats", fill = "Plans\nsampled from") +
    scale_fill_manual(values = PAL_DAS)
ggsave(here("figs/pa_n_dem.pdf"), width = 6.5, height = 2)

## Democratic mean-median -----

#' Median-minus-mean Democratic district share for each plan.
#'
#' District Democratic share is group_frac(pa_orig, ndv, ndv + nrv).
#'
#' @param pl A redist_plans object.
#' @return One row per `draw` with `median_mean` = median - mean of the
#'   district Democratic shares.
calc_dem_mm <- function(pl) {
    with_dem <- mutate(pl, dem = group_frac(pa_orig, ndv, ndv + nrv))
    summarize(group_by(with_dem, draw),
              median_mean = median(dem) - mean(dem))
}
# Median - mean Democratic share for every plan in each ensemble.
mm_res <- bind_rows(
    orig = calc_dem_mm(plans_orig),
    das12 = calc_dem_mm(plans_das12),
    das4 = calc_dem_mm(plans_das4),
    .id = "sampled_from") %>%
    as_tibble()

# Ensemble mean of (median - mean) per sampling source, with a +/- 3 SE
# interval; the SE uses each ensemble's effective sample size.
mm_res %>%
    filter(!(draw %in% c("gov", "resp", "court", "cd"))) %>%
    group_by(sampled_from) %>%
    summarize(mmm = mean(median_mean),
              sd = sd(median_mean),
              se = sd / sqrt(n_eff[[sampled_from[1]]]),
              lower = mmm - 3 * se,
              upper = mmm + 3 * se)

# Histogram over sampled plans only (reference plans excluded, as in the
# summary above). Bars are coloured via `fill`, so the legend title is set on
# `fill`; a `color` title would never be shown.
mm_res %>%
    filter(!(draw %in% c("gov", "resp", "court", "cd"))) %>%
ggplot(aes(x = median_mean, fill = lbl_source(sampled_from))) +
    geom_vline(xintercept = 0, lty = "dashed") +
    geom_histogram(binwidth = 0.0025, position = "dodge") +
    scale_fill_manual(values = PAL_DAS) +
    scale_x_continuous(labels = function(x) percent(x, 1)) +
    scale_y_continuous(NULL, expand = expansion(mult = c(0, 0.05))) +
    labs(x = "Median - mean Democratic share",
         fill = "Plans\nsampled from") +
    theme_ppmf()


## Turnout/Partisanship ----------------------------------------------------

#' Scatter of district population error against Democratic share.
#'
#' @param pl A redist_plans object.
#' @param qty Bare name of the population column compared with `pop_orig`
#'   (e.g. `pop_das12`); the y-axis is `qty - pop_orig`.
#' @param title Plot title.
#' @return A ggplot with one point per sampled district and a linear fit per
#'   district, after districts are renumbered with number_by(dem).
plot_error_distr <- function(pl, qty, title) {
    pl %>%
        subset_sampled() %>%
        mutate(dem = group_frac(pa_orig, ndv, ndv + nrv)) %>%
        number_by(dem) %>%
        as_tibble() %>%
    ggplot(aes(dem, {{ qty }} - pop_orig, group = district)) +
        geom_point(size = 0.1, alpha = 0.1) +
        geom_smooth(aes(color = district), method = lm, formula = y ~ x,
                    se = FALSE) +
        labs(title = title) +
        wacolors::scale_color_wa_c("sea_star") +
        theme_ppmf()
}

plot_error_distr(plans_das12, pop_das12, "DAS 12.2") +
    plot_error_distr(plans_das4, pop_das4, "DAS 4.5") +
    plot_layout(guides = "collect")
ggsave(here("figs/pa_error_distr.png"), width = 9, height = 4, dpi = 400)

## District-level race and partisan effects --------------------------------

#' District group shares for every plan, with districts renumbered by share.
#'
#' @param pl A redist_plans object.
#' @param type "partisan" for Democratic share ndv / (ndv + nrv); any other
#'   value gives the non-white share of voting-age population
#'   (vap_orig - vap_white_orig) / vap_orig.
#' @return A tibble with the share in column `grp`, after number_by(grp).
calc_grp_pct <- function(pl, type = "partisan") {
    pl <- if (type == "partisan") {
        mutate(pl, grp = group_frac(pa_orig, ndv, ndv + nrv))
    } else {
        mutate(pl, grp = group_frac(pa_orig, vap_orig - vap_white_orig,
                                    vap_orig))
    }
    as_tibble(number_by(pl, grp))
}

# District partisan and minority shares for each ensemble, stacked with a
# `source` column naming the ensemble.
dem_pcts <- bind_rows(
    lapply(list(orig = plans_orig, das12 = plans_das12, das4 = plans_das4),
           calc_grp_pct, type = "partisan"),
    .id = "source"
)
min_pcts <- bind_rows(
    lapply(list(orig = plans_orig, das12 = plans_das12, das4 = plans_das4),
           calc_grp_pct, type = "race"),
    .id = "source"
)

#' Boxplots of each district's group share relative to the enacted plan.
#'
#' @param pcts Stacked calc_grp_pct() output with a `source` column; the
#'   enacted plan is the `draw == "cd"` row of source "orig".
#' @param xlab x-axis label.
#' @return A ggplot with one box per (district, source), excluding the
#'   enacted plan itself.
plot_grp_norm <- function(pcts, xlab) {
    pcts %>%
        group_by(district) %>%
        mutate(ref_grp = grp[draw == "cd" & source == "orig"]) %>%
        ungroup() %>%
        filter(draw != "cd") %>%
    ggplot(aes(as.factor(district), grp - ref_grp, fill = lbl_source(source))) +
        geom_hline(yintercept = 0, lty = "dashed", color = "#444444") +
        geom_boxplot(outlier.size = 0.1) +
        scale_fill_manual(values = PAL_DAS) +
        scale_y_continuous("Difference from enacted plan",
                           labels = percent) +
        labs(x = xlab, fill = "Plans\nsampled from") +
        theme_ppmf()
}

plot_grp_norm(dem_pcts, "Districts, ordered by Democratic two-party vote share")
# The "race" measure is the non-white share of the voting-age population
# (vap_orig - vap_white_orig) / vap_orig, not a vote share.
plot_grp_norm(min_pcts,
              "Districts, ordered by minority share of voting-age population")

#' Compare each district's mean group share between two ensembles.
#'
#' Both ensembles are measured against the enacted plan (the `draw == "cd"`
#' row of source "orig"). The difference of their mean deviations, and its
#' standard error (each ensemble's SE scaled by its `n_eff` entry), give a
#' t-statistic per district.
#'
#' @param pcts Stacked calc_grp_pct() output with a `source` column.
#' @param compare_to Source compared against "orig" (default "das12").
#' @return Ungrouped tibble with columns `district`, `grp` (mean share in the
#'   "orig" ensemble), `diff`, `se` and `t`, sorted by decreasing `t`.
calc_grp_diffs <- function(pcts, compare_to = "das12") {
    stopifnot(is.character(compare_to), length(compare_to) == 1,
              compare_to != "orig")
    pcts %>%
        group_by(district) %>%
        # Enacted-plan share for this district; needs the "orig" cd rows, so
        # it is computed before reference draws and other sources are dropped.
        mutate(ref_grp = grp[draw == "cd" & source == "orig"]) %>%
        ungroup() %>%
        filter(draw != "cd", source %in% c("orig", compare_to)) %>%
        group_by(district, source) %>%
        summarize(mean_grp = mean(grp),
                  diff = mean(grp - ref_grp),
                  se = sd(grp - ref_grp) / sqrt(n_eff[[source[1]]]),
                  .groups = "drop") %>%
        # Fixed labels so the widened column names do not depend on compare_to.
        mutate(source = if_else(source == "orig", "orig", "cmp")) %>%
        pivot_wider(names_from = source, values_from = c(diff, se, mean_grp)) %>%
        mutate(diff = diff_cmp - diff_orig,
               se = sqrt(se_cmp^2 + se_orig^2),
               t = abs(diff) / se) %>%
        select(district, grp = mean_grp_orig, diff, se, t) %>%
        arrange(desc(t))
}

calc_grp_diffs(dem_pcts)
calc_grp_diffs(min_pcts)
